In [1]:
import os
import shutil
import sys

# Make the project sources importable; this must run before the project imports below.
sys.path.insert(0, './src')

from models.classifier import FoodClassifier
from project_utils import calculate_metrics
import plot
# Build the classifier once; the following cells reuse this `model` instance.
model = FoodClassifier()

# Smoke test: classify one sample image and display it, using the predicted
# label and confidence as the figure title.
example_image = './data/example.JPG'
results = model.predict_single(example_image)
caption = "{}: {:.2f}".format(results['label'], results['confidence'])
plot.single_image(example_image, title=caption)
Model initialized, using device: cuda
Run prediction on our collected dataset as the baseline
In [2]:
# Baseline: classify the collected images exactly as they are on disk.
# `dataset_path` is reused by the later cells as the source folder.
dataset_path = './data/raw'
results = model.predict_folder(dataset_path)

# Display the misclassified images first, then print the aggregate metrics.
plot.wrong_predictions(results, folder=dataset_path)

metrics = calculate_metrics(**results)
print(metrics)
Loaded 205 images from ./data/raw Predicted 205 images in 8.03s
{'accuracy': 0.5658536585365853, 'f1_score': 0.6658446700550718}
Part II - Individual Algorithm Evaluation¶
Apply each image-processing algorithm to our collected data individually, use the processed images as the input to the model, and determine which individual algorithm performs best for each distortion type.
In [3]:
from preprocessing.lowlight import gamma, CLAHE, SSRetinex
from preprocessing.deblurr import ssk, usm, swf
from preprocessing.downscaling import Lanczos, Lanczos_SAID, DPID
from project_utils import preprocess_folder
processor = model.processor
# Scratch folder for preprocessed copies of the dataset (also used by later cells).
output_path = './data/preprocessed'


def evaluate_algorithms(funcs, do_resize, do_center_crop=False):
    """Evaluate the classifier on the dataset preprocessed by each function in turn.

    For every function in ``funcs`` the raw dataset is preprocessed into its own
    sub-folder of ``output_path``, classified, scored, and the wrong predictions
    are plotted. The scratch folder is removed afterwards, even if a step fails.

    Args:
        funcs: Preprocessing callables accepted by ``preprocess_folder``; each
            one's ``__name__`` is used as its sub-folder name and result key.
        do_resize: Value assigned to ``processor.do_resize`` for this group.
        do_center_crop: Value assigned to ``processor.do_center_crop``.

    Returns:
        dict mapping each function name to the metrics returned by
        ``calculate_metrics``.
    """
    processor.do_resize = do_resize
    processor.do_center_crop = do_center_crop

    metrics_by_name = {}
    # Create only the parent folder; preprocess_folder creates its own output folder.
    os.makedirs(output_path, exist_ok=True)
    try:
        for func in funcs:
            # Join with a real path separator so every result lives *inside*
            # output_path rather than in a sibling such as './data/preprocessedgamma'.
            out_dir = os.path.join(output_path, func.__name__)
            preprocess_folder(func, input_dir=dataset_path, output_dir=out_dir)
            results = model.predict_folder(out_dir)
            metrics = calculate_metrics(**results)
            print(f"Metrics for {func.__name__}: {metrics}")
            plot.wrong_predictions(results, folder=out_dir)
            metrics_by_name[func.__name__] = metrics
    finally:
        # Leave no preprocessed copies behind, so later cells that write to
        # output_path start from a clean folder.
        shutil.rmtree(output_path, ignore_errors=True)
    return metrics_by_name


# Low-light and deblurring keep the processor's resize step enabled; it is
# disabled for the downscaling group (presumably because those algorithms already
# produce model-sized images -- TODO confirm). Downscaling runs last, so the
# processor is left with do_resize=False and do_center_crop=False.
individual_metrics = {
    'low-light': evaluate_algorithms([gamma, CLAHE, SSRetinex], do_resize=True),
    'deblurring': evaluate_algorithms([ssk, usm, swf], do_resize=True),
    'downscaling': evaluate_algorithms([Lanczos, Lanczos_SAID, DPID], do_resize=False),
}

# Report the winner of each group: highest accuracy, ties broken by F1 score.
for group_name, group_metrics in individual_metrics.items():
    best_name = max(
        group_metrics,
        key=lambda name: (group_metrics[name]['accuracy'], group_metrics[name]['f1_score']),
    )
    print(f"Best {group_name} algorithm: {best_name} {group_metrics[best_name]}")
Loaded 205 images from ./data/preprocessedgamma
Predicted 205 images in 16.12s
Metrics for gamma: {'accuracy': 0.5756097560975609, 'f1_score': 0.6678271677670021}
Loaded 205 images from ./data/preprocessedCLAHE
Predicted 205 images in 16.33s
Metrics for CLAHE: {'accuracy': 0.5756097560975609, 'f1_score': 0.6880086635284361}
Loaded 205 images from ./data/preprocessedSSRetinex
Predicted 205 images in 17.18s
Metrics for SSRetinex: {'accuracy': 0.2975609756097561, 'f1_score': 0.4352442714584815}
Loaded 205 images from ./data/preprocessedssk
Predicted 205 images in 16.88s
Metrics for ssk: {'accuracy': 0.551219512195122, 'f1_score': 0.6517573080750508}
Loaded 205 images from ./data/preprocessedusm
Predicted 205 images in 17.65s
Metrics for usm: {'accuracy': 0.551219512195122, 'f1_score': 0.6554786566545715}
Loaded 205 images from ./data/preprocessedswf
Predicted 205 images in 14.86s
Metrics for swf: {'accuracy': 0.45365853658536587, 'f1_score': 0.5325994826796637}
Loaded 205 images from ./data/preprocessedLanczos
Predicted 205 images in 6.62s
Metrics for Lanczos: {'accuracy': 0.5756097560975609, 'f1_score': 0.6702244717777451}
[INFO] Loading SAID model 'SAID_Lanczos' from c:\Users\Austin Chen\Desktop\UCSD\2025_Fall\ECE253\Project\ECE253-food-classification\src\preprocessing\..\..\SAID\pretrained_models\SAID_Lanczos.pth
c:\Users\Austin Chen\Desktop\UCSD\2025_Fall\ECE253\Project\ECE253-food-classification\./src\preprocessing\downscaling.py:74: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
ckpt = torch.load(ckpt_path, map_location="cpu")
c:\Users\Austin Chen\Desktop\UCSD\2025_Fall\ECE253\Project\ECE253-food-classification\./src\preprocessing\downscaling.py:226: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with torch.cuda.amp.autocast(enabled=amp_enabled, dtype=torch.float16):
c:\anaconda250618\envs\ece253_env\Lib\site-packages\torch\functional.py:534: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at C:\cb\pytorch_1000000000000\work\aten\src\ATen\native\TensorShape.cpp:3596.)
return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]
Loaded 205 images from ./data/preprocessedLanczos_SAID
Predicted 205 images in 6.92s
Metrics for Lanczos_SAID: {'accuracy': 0.5560975609756098, 'f1_score': 0.6504567606954238}
Loaded 205 images from ./data/preprocessedDPID
Predicted 205 images in 6.48s
Metrics for DPID: {'accuracy': 0.5024390243902439, 'f1_score': 0.6202426763231333}
After finding the best algorithm for each distortion, chain them into a single pipeline and evaluate the combination
In [5]:
# Chain the three chosen algorithms: CLAHE -> ssk -> Lanczos.
# The first stage reads the raw dataset; every later stage rewrites the
# working folder in place.
stage_input = dataset_path
for stage in (CLAHE, ssk, Lanczos):
    preprocess_folder(stage, input_dir=stage_input, output_dir=output_path)
    stage_input = output_path

# Turn off the processor's own resize / centre-crop before predicting on the
# downscaled images.
model.processor.do_resize = False
model.processor.do_center_crop = False

results = model.predict_folder(output_path)
metrics = calculate_metrics(**results)
print(f"Metrics for best combination: {metrics}")
plot.wrong_predictions(results, folder=output_path)

# Remove the scratch folder once the results have been plotted.
shutil.rmtree(output_path)
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 6.60s
Metrics for best combination: {'accuracy': 0.4292682926829268, 'f1_score': 0.5669924198829516}
Part III - Combination Search¶
Exhaustively evaluate every low-light → deblur → downscale pipeline and find the best combination
In [7]:
def Identity(image):
    """Return ``image`` unchanged; a no-op stand-in for a skipped pipeline stage."""
    return image
# Find the best low-light -> deblur -> downscale combination by exhaustive search.
model.processor.do_resize = False
model.processor.do_center_crop = False

output_path = './data/preprocessed'
# Each stage writes to its own folder under output_path. This lets the low-light
# result be shared by all deblur/downscale choices, and the deblur result by all
# downscale choices, instead of recomputing both from the raw images inside the
# innermost loop.
lowlight_dir = os.path.join(output_path, 'lowlight')
deblur_dir = os.path.join(output_path, 'deblur')
downscale_dir = os.path.join(output_path, 'downscale')

all_metrics = {}  # (low-light, deblur, downscale) names -> metrics, kept for later inspection
best_metrics = None
best_combination = None

# Create only the parent folder; preprocess_folder creates the stage folders.
os.makedirs(output_path, exist_ok=True)
try:
    for func1 in [Identity, gamma, CLAHE, SSRetinex]:
        # Start every stage from an empty folder so no file from a previous
        # algorithm can leak into the next one.
        shutil.rmtree(lowlight_dir, ignore_errors=True)
        preprocess_folder(func1, input_dir=dataset_path, output_dir=lowlight_dir)

        for func2 in [Identity, ssk, usm, swf]:
            shutil.rmtree(deblur_dir, ignore_errors=True)
            preprocess_folder(func2, input_dir=lowlight_dir, output_dir=deblur_dir)

            for func3 in [Lanczos, Lanczos_SAID, DPID]:
                shutil.rmtree(downscale_dir, ignore_errors=True)
                preprocess_folder(func3, input_dir=deblur_dir, output_dir=downscale_dir)

                results = model.predict_folder(downscale_dir)
                metrics = calculate_metrics(**results)
                print(f"Metrics for {func1.__name__} + {func2.__name__} + {func3.__name__}: {metrics}")
                plot.wrong_predictions(results, folder=downscale_dir)

                combination = (func1.__name__, func2.__name__, func3.__name__)
                all_metrics[combination] = metrics
                # Strict '>' keeps the first combination seen when accuracies tie.
                if best_metrics is None or metrics['accuracy'] > best_metrics['accuracy']:
                    best_metrics = metrics
                    best_combination = combination
finally:
    # Always remove the scratch folders, even when a stage or the prediction fails.
    shutil.rmtree(output_path, ignore_errors=True)

print(f"Best combination: {best_combination} with metrics: {best_metrics}")
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 6.40s
Metrics for Identity + Identity + Lanczos: {'accuracy': 0.5756097560975609, 'f1_score': 0.6702244717777451}
c:\Users\Austin Chen\Desktop\UCSD\2025_Fall\ECE253\Project\ECE253-food-classification\./src\preprocessing\downscaling.py:226: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with torch.cuda.amp.autocast(enabled=amp_enabled, dtype=torch.float16):
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 7.15s
Metrics for Identity + Identity + Lanczos_SAID: {'accuracy': 0.5560975609756098, 'f1_score': 0.6504567606954238}
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 6.47s
Metrics for Identity + Identity + DPID: {'accuracy': 0.5024390243902439, 'f1_score': 0.6202426763231333}
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 6.38s
Metrics for Identity + ssk + Lanczos: {'accuracy': 0.5365853658536586, 'f1_score': 0.6369884345577189}
c:\Users\Austin Chen\Desktop\UCSD\2025_Fall\ECE253\Project\ECE253-food-classification\./src\preprocessing\downscaling.py:226: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with torch.cuda.amp.autocast(enabled=amp_enabled, dtype=torch.float16):
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 6.95s
Metrics for Identity + ssk + Lanczos_SAID: {'accuracy': 0.5463414634146342, 'f1_score': 0.6551917573494846}
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 6.46s
Metrics for Identity + ssk + DPID: {'accuracy': 0.32682926829268294, 'f1_score': 0.45701842377303065}
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 6.26s
Metrics for Identity + usm + Lanczos: {'accuracy': 0.5463414634146342, 'f1_score': 0.6486484452115001}
c:\Users\Austin Chen\Desktop\UCSD\2025_Fall\ECE253\Project\ECE253-food-classification\./src\preprocessing\downscaling.py:226: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with torch.cuda.amp.autocast(enabled=amp_enabled, dtype=torch.float16):
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 7.22s
Metrics for Identity + usm + Lanczos_SAID: {'accuracy': 0.526829268292683, 'f1_score': 0.6408684897536749}
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 6.45s
Metrics for Identity + usm + DPID: {'accuracy': 0.3170731707317073, 'f1_score': 0.44920419785622817}
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 6.81s
Metrics for Identity + swf + Lanczos: {'accuracy': 0.47804878048780486, 'f1_score': 0.5650475784109311}
c:\Users\Austin Chen\Desktop\UCSD\2025_Fall\ECE253\Project\ECE253-food-classification\./src\preprocessing\downscaling.py:226: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with torch.cuda.amp.autocast(enabled=amp_enabled, dtype=torch.float16):
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 7.45s
Metrics for Identity + swf + Lanczos_SAID: {'accuracy': 0.4585365853658537, 'f1_score': 0.5483907095839314}
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 6.89s
Metrics for Identity + swf + DPID: {'accuracy': 0.4146341463414634, 'f1_score': 0.5272544860998482}
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 7.03s
Metrics for gamma + Identity + Lanczos: {'accuracy': 0.5853658536585366, 'f1_score': 0.6790981235782262}
c:\Users\Austin Chen\Desktop\UCSD\2025_Fall\ECE253\Project\ECE253-food-classification\./src\preprocessing\downscaling.py:226: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with torch.cuda.amp.autocast(enabled=amp_enabled, dtype=torch.float16):
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 7.52s
Metrics for gamma + Identity + Lanczos_SAID: {'accuracy': 0.5902439024390244, 'f1_score': 0.670396574973935}
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 6.75s
Metrics for gamma + Identity + DPID: {'accuracy': 0.5317073170731708, 'f1_score': 0.648310872817641}
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 7.03s
Metrics for gamma + ssk + Lanczos: {'accuracy': 0.5414634146341464, 'f1_score': 0.6413556762119568}
c:\Users\Austin Chen\Desktop\UCSD\2025_Fall\ECE253\Project\ECE253-food-classification\./src\preprocessing\downscaling.py:226: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with torch.cuda.amp.autocast(enabled=amp_enabled, dtype=torch.float16):
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 7.58s
Metrics for gamma + ssk + Lanczos_SAID: {'accuracy': 0.551219512195122, 'f1_score': 0.6544385376762591}
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 6.94s
Metrics for gamma + ssk + DPID: {'accuracy': 0.35121951219512193, 'f1_score': 0.4754575668145066}
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 6.71s
Metrics for gamma + usm + Lanczos: {'accuracy': 0.5560975609756098, 'f1_score': 0.6535943797983483}
c:\Users\Austin Chen\Desktop\UCSD\2025_Fall\ECE253\Project\ECE253-food-classification\./src\preprocessing\downscaling.py:226: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with torch.cuda.amp.autocast(enabled=amp_enabled, dtype=torch.float16):
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 7.46s
Metrics for gamma + usm + Lanczos_SAID: {'accuracy': 0.526829268292683, 'f1_score': 0.6287962928329764}
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 6.82s
Metrics for gamma + usm + DPID: {'accuracy': 0.35609756097560974, 'f1_score': 0.4823236249576427}
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 6.66s
Metrics for gamma + swf + Lanczos: {'accuracy': 0.48292682926829267, 'f1_score': 0.5644693029369593}
c:\Users\Austin Chen\Desktop\UCSD\2025_Fall\ECE253\Project\ECE253-food-classification\./src\preprocessing\downscaling.py:226: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with torch.cuda.amp.autocast(enabled=amp_enabled, dtype=torch.float16):
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 7.63s
Metrics for gamma + swf + Lanczos_SAID: {'accuracy': 0.4878048780487805, 'f1_score': 0.56197294156649}
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 6.64s
Metrics for gamma + swf + DPID: {'accuracy': 0.4926829268292683, 'f1_score': 0.5786977377795168}
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 6.78s
Metrics for CLAHE + Identity + Lanczos: {'accuracy': 0.551219512195122, 'f1_score': 0.6707226052561637}
c:\Users\Austin Chen\Desktop\UCSD\2025_Fall\ECE253\Project\ECE253-food-classification\./src\preprocessing\downscaling.py:226: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with torch.cuda.amp.autocast(enabled=amp_enabled, dtype=torch.float16):
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 7.44s
Metrics for CLAHE + Identity + Lanczos_SAID: {'accuracy': 0.5463414634146342, 'f1_score': 0.6647243561752949}
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 6.47s
Metrics for CLAHE + Identity + DPID: {'accuracy': 0.44878048780487806, 'f1_score': 0.586357987552045}
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 6.56s
Metrics for CLAHE + ssk + Lanczos: {'accuracy': 0.4292682926829268, 'f1_score': 0.5669924198829516}
c:\Users\Austin Chen\Desktop\UCSD\2025_Fall\ECE253\Project\ECE253-food-classification\./src\preprocessing\downscaling.py:226: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with torch.cuda.amp.autocast(enabled=amp_enabled, dtype=torch.float16):
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 7.08s
Metrics for CLAHE + ssk + Lanczos_SAID: {'accuracy': 0.44390243902439025, 'f1_score': 0.5781224124613084}
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 6.44s
Metrics for CLAHE + ssk + DPID: {'accuracy': 0.3121951219512195, 'f1_score': 0.4465722898902692}
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 6.52s
Metrics for CLAHE + usm + Lanczos: {'accuracy': 0.44390243902439025, 'f1_score': 0.5741293236007596}
c:\Users\Austin Chen\Desktop\UCSD\2025_Fall\ECE253\Project\ECE253-food-classification\./src\preprocessing\downscaling.py:226: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with torch.cuda.amp.autocast(enabled=amp_enabled, dtype=torch.float16):
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 7.11s
Metrics for CLAHE + usm + Lanczos_SAID: {'accuracy': 0.4292682926829268, 'f1_score': 0.5620854920553202}
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 6.59s
Metrics for CLAHE + usm + DPID: {'accuracy': 0.28780487804878047, 'f1_score': 0.4227421023270832}
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 6.40s
Metrics for CLAHE + swf + Lanczos: {'accuracy': 0.48292682926829267, 'f1_score': 0.5884499599410512}
c:\Users\Austin Chen\Desktop\UCSD\2025_Fall\ECE253\Project\ECE253-food-classification\./src\preprocessing\downscaling.py:226: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with torch.cuda.amp.autocast(enabled=amp_enabled, dtype=torch.float16):
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 7.11s
Metrics for CLAHE + swf + Lanczos_SAID: {'accuracy': 0.47804878048780486, 'f1_score': 0.5830608824416874}
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 6.56s
Metrics for CLAHE + swf + DPID: {'accuracy': 0.4585365853658537, 'f1_score': 0.5769705488709841}
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 6.34s
Metrics for SSRetinex + Identity + Lanczos: {'accuracy': 0.3073170731707317, 'f1_score': 0.44895730025774117}
c:\Users\Austin Chen\Desktop\UCSD\2025_Fall\ECE253\Project\ECE253-food-classification\./src\preprocessing\downscaling.py:226: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with torch.cuda.amp.autocast(enabled=amp_enabled, dtype=torch.float16):
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 7.11s
Metrics for SSRetinex + Identity + Lanczos_SAID: {'accuracy': 0.2926829268292683, 'f1_score': 0.4278131534229095}
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 6.55s
Metrics for SSRetinex + Identity + DPID: {'accuracy': 0.2682926829268293, 'f1_score': 0.40061413956679376}
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 6.35s
Metrics for SSRetinex + ssk + Lanczos: {'accuracy': 0.2780487804878049, 'f1_score': 0.4124036850562751}
c:\Users\Austin Chen\Desktop\UCSD\2025_Fall\ECE253\Project\ECE253-food-classification\./src\preprocessing\downscaling.py:226: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with torch.cuda.amp.autocast(enabled=amp_enabled, dtype=torch.float16):
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 7.00s
Metrics for SSRetinex + ssk + Lanczos_SAID: {'accuracy': 0.2926829268292683, 'f1_score': 0.4287537699111319}
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 6.66s
Metrics for SSRetinex + ssk + DPID: {'accuracy': 0.23902439024390243, 'f1_score': 0.36723107030636193}
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 6.37s
Metrics for SSRetinex + usm + Lanczos: {'accuracy': 0.2780487804878049, 'f1_score': 0.4047698501569127}
c:\Users\Austin Chen\Desktop\UCSD\2025_Fall\ECE253\Project\ECE253-food-classification\./src\preprocessing\downscaling.py:226: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with torch.cuda.amp.autocast(enabled=amp_enabled, dtype=torch.float16):
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 7.26s
Metrics for SSRetinex + usm + Lanczos_SAID: {'accuracy': 0.2780487804878049, 'f1_score': 0.4069355088177994}
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 6.54s
Metrics for SSRetinex + usm + DPID: {'accuracy': 0.21951219512195122, 'f1_score': 0.3356562137049941}
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 6.23s
Metrics for SSRetinex + swf + Lanczos: {'accuracy': 0.24390243902439024, 'f1_score': 0.3654234310915117}
c:\Users\Austin Chen\Desktop\UCSD\2025_Fall\ECE253\Project\ECE253-food-classification\./src\preprocessing\downscaling.py:226: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with torch.cuda.amp.autocast(enabled=amp_enabled, dtype=torch.float16):
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 6.90s
Metrics for SSRetinex + swf + Lanczos_SAID: {'accuracy': 0.24390243902439024, 'f1_score': 0.36259667338670515}
Loaded 205 images from ./data/preprocessed
Predicted 205 images in 6.60s
Metrics for SSRetinex + swf + DPID: {'accuracy': 0.22926829268292684, 'f1_score': 0.3521313531401716}
Best combination: ('gamma', 'Identity', 'Lanczos_SAID') with metrics: {'accuracy': 0.5902439024390244, 'f1_score': 0.670396574973935}
Part IV - Model Fine-Tuning¶
In [9]:
from models.trainer import fine_tune

# Re-enable the processor's resize step before fine-tuning on the raw images.
model.processor.do_resize = True

# Training hyperparameters, grouped so they are easy to find and adjust.
fine_tune_settings = {
    'val_ratio': 0.3,
    'epochs': 20,
    'batch_size': 4,
    'lr': 5e-5,
}

fine_tune(model_wrapper=model, data_dir=dataset_path, **fine_tune_settings)
Loaded 205 images from ./data/raw Epoch 1/20 - Loss: 1.5862 - Accuracy: 0.6643 Validation accuracy: 0.6774 Epoch 2/20 - Loss: 1.1297 - Accuracy: 0.6923 Validation accuracy: 0.5968 Epoch 3/20 - Loss: 1.0895 - Accuracy: 0.7110 Validation accuracy: 0.8548 Epoch 4/20 - Loss: 0.5649 - Accuracy: 0.7483 Validation accuracy: 0.7742 Epoch 5/20 - Loss: 0.2582 - Accuracy: 0.7832 Validation accuracy: 0.8387 Epoch 6/20 - Loss: 0.1610 - Accuracy: 0.8089 Validation accuracy: 0.8710 Epoch 7/20 - Loss: 0.1333 - Accuracy: 0.8312 Validation accuracy: 0.8710 Epoch 8/20 - Loss: 0.2125 - Accuracy: 0.8453 Validation accuracy: 0.8548 Epoch 9/20 - Loss: 0.0219 - Accuracy: 0.8617 Validation accuracy: 0.9032 Epoch 10/20 - Loss: 0.0125 - Accuracy: 0.8755 Validation accuracy: 0.9032 Epoch 11/20 - Loss: 0.0087 - Accuracy: 0.8868 Validation accuracy: 0.9194 Epoch 12/20 - Loss: 0.0012 - Accuracy: 0.8963 Validation accuracy: 0.9032 Epoch 13/20 - Loss: 0.0009 - Accuracy: 0.9042 Validation accuracy: 0.9032 Epoch 14/20 - Loss: 0.0007 - Accuracy: 0.9111 Validation accuracy: 0.9032 Epoch 15/20 - Loss: 0.0006 - Accuracy: 0.9170 Validation accuracy: 0.9032 Epoch 16/20 - Loss: 0.0006 - Accuracy: 0.9222 Validation accuracy: 0.9032 Epoch 17/20 - Loss: 0.0006 - Accuracy: 0.9268 Validation accuracy: 0.9032 Epoch 18/20 - Loss: 0.0005 - Accuracy: 0.9308 Validation accuracy: 0.9032 Epoch 19/20 - Loss: 0.0005 - Accuracy: 0.9345 Validation accuracy: 0.9032 Epoch 20/20 - Loss: 0.0005 - Accuracy: 0.9378 Validation accuracy: 0.9032